import argparse

from Data.CHILD.CHILD import child_load
from Data.SACHS.SACHS import sachs_load
from Data.SyntheticData import graph_generate
from Data.utils import Continuous2Discrete, Part2Discrete
from GES.GES import ges
from utils import F1_score_nparray, SHD_nparray


# Maps each ``--score`` option to the name of the local score function that
# ``ges`` is asked to use.
_SCORE_FUNCTIONS = {
    'CV': 'local_score_CV_general',
    'Marg': 'local_score_marginal_general',
    'Ours': 'local_score_cat',
}

# Datasets produced by ``graph_generate`` (optionally discretized afterwards).
_SYNTHETIC_DATASETS = ('con', 'dis', 'mix')
# Every dataset accepted by ``--dataset``.
_DATASETS = _SYNTHETIC_DATASETS + ('sach', 'child')


def _build_parser():
    """Return the argument parser describing the command-line options of ``main``."""
    parser = argparse.ArgumentParser(description="parameter setting")
    parser.add_argument("--dataset", type=str, default="con", choices=_DATASETS,
                        help="dataset name, 'con', 'dis', 'mix', 'child' and 'sach'. ")
    # ``choices`` rejects unknown scores at parse time; an ``assert`` would be
    # stripped under ``python -O`` and leave the result variable unbound.
    parser.add_argument("--score", type=str, default='Ours', choices=list(_SCORE_FUNCTIONS),
                        help="score function option, 'CV', 'Marg', 'Ours' ")
    parser.add_argument("--device", type=str, default="cuda", help='training on "cpu" or "cuda" ')
    parser.add_argument("--nums", type=int, default=500, help="sample size")
    parser.add_argument("--seed", type=int, default=1111, help="random seed")
    parser.add_argument("--gd", type=float, default=0.5, help="graph density")
    parser.add_argument("--epochs", type=int, default=500,
                        help="value of the 'epochs' entry in the parameters passed to ges")
    return parser


def _load_data(opt):
    """Load or generate the dataset selected by ``opt.dataset``.

    Synthetic datasets come from ``graph_generate``. For 'dis' and 'mix' the
    ``'data_mat'`` entry is then replaced by a fully or partially discretized
    version. 'sach' and 'child' use their dedicated loaders.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not a known dataset name.
    """
    if opt.dataset in _SYNTHETIC_DATASETS:
        data = graph_generate(opt.nums, opt.gd, seeds=opt.seed)
        if opt.dataset == 'dis':
            data['data_mat'] = Continuous2Discrete(data)
        elif opt.dataset == 'mix':
            data['data_mat'] = Part2Discrete(data)
        return data
    if opt.dataset == 'sach':
        return sachs_load(nums=opt.nums, seeds=opt.seed)
    if opt.dataset == 'child':
        return child_load(nums=opt.nums, seeds=opt.seed)
    # Unreachable through the CLI because of ``choices``; kept as a safeguard.
    raise NotImplementedError("Unknown dataset")


def main(argv=None):
    """Run GES with the selected local score on a dataset and evaluate the result.

    Prints the GES parameters, the ground-truth graph, the estimated graph, and
    the F1 score and SHD between them.

    Args:
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as ``argparse`` does.

    Returns:
        Tuple ``(F1_score, SHD, estimated_graph)``, where ``estimated_graph``
        is ``Record['G'].graph`` from the ``ges`` result.
    """
    opt = _build_parser().parse_args(argv)
    data = _load_data(opt)

    parameters = {'epochs': opt.epochs, 'device': opt.device}
    print(parameters)

    truth_graph = data['G']
    print("truth graph")
    print(truth_graph)

    record = ges(data, _SCORE_FUNCTIONS[opt.score], parameters=parameters)
    est_graph = record['G'].graph

    print("cat est graph")
    print(est_graph)
    F1_score = F1_score_nparray(truth_graph, est_graph)
    SHD = SHD_nparray(truth_graph, est_graph)
    print("CAT --> F1_score: ", F1_score, "SHD: ", SHD)
    return F1_score, SHD, est_graph

if __name__ == "__main__":
    # Script entry point: main() prints its own metrics, so the returned
    # (F1, SHD, graph) tuple is intentionally ignored here.
    main()